import json
from celery.decorators import task
from .models import *
import os, requests
import re
from urllib.parse import urlparse
from lib.config_json import *
from .models import filter_data
from celery.decorators import task
import time
import random

# Parent of the directory that contains this module.
BASE_DIR = os.path.dirname(
    os.path.dirname(
        os.path.abspath(__file__)
    )
)


def is_json(value):
    """Classify *value* by whether ``json.loads`` accepts it.

    Returns ``"json"`` when *value* parses as JSON and ``"normal"`` when
    parsing raises ``ValueError``.  Other exceptions (e.g. ``TypeError`` for
    non-string input) propagate to the caller.
    """
    try:
        json.loads(value)
    except ValueError:
        return "normal"
    else:
        return "json"


def _set_pair_value(pairs_text, parm, modify):
    """Return ``k1=v1&k2=v2`` text with the value of key *parm* set to *modify*.

    Only the first pair whose key equals *parm* is changed.  Returns ``None``
    when no pair has that key.
    """
    pairs = pairs_text.split('&')
    for idx, pair in enumerate(pairs):
        if pair.split('=')[0] == parm:
            pairs[idx] = parm + '=' + modify
            return '&'.join(pairs)
    return None


def _modify_token_packet(data, parm, modify):
    """Return raw request *data* with parameter *parm* set to *modify*.

    The URL query string is checked first (GET and POST).  For POST, the body
    (the packet's last line) is tried next, either as a top-level JSON object
    key or as ``name=value`` form pairs.  Returns ``""`` when the packet is not
    GET/POST or does not carry *parm*.
    """
    lines = data.split('\n')
    request_line = lines[0].split(' ')
    if len(request_line) < 2:
        return ""
    method, url = request_line[0], request_line[1]
    if method not in ('GET', 'POST'):
        return ""

    # Parameter carried in the URL query string.
    base, sep, query = url.rpartition('?')
    if sep:
        new_query = _set_pair_value(query, parm, modify)
        if new_query is not None:
            request_line[1] = base + sep + new_query
            lines[0] = ' '.join(request_line)
            return '\n'.join(lines)
    if method == 'GET' or len(lines) < 2:
        return ""

    # Parameter carried in the POST body.
    body = lines[-1]
    try:
        json_body = json.loads(body)
    except ValueError:
        # Form body, e.g. token=asdasdas&id=1
        new_body = _set_pair_value(body, parm, modify)
    else:
        # JSON body, e.g. {"token":"asdasdas","id":1}
        if not isinstance(json_body, dict) or parm not in json_body:
            return ""
        json_body[parm] = modify
        new_body = json.dumps(json_body, separators=(',', ':'))
    if new_body is None:
        return ""
    lines[-1] = new_body
    return '\n'.join(lines)


def modify_auth(tp="", parm="", host="", modify=""):
    """Rewrite authentication data in stored filter_data packets.

    Only records whose ``host`` contains *host* (case-insensitive) are touched.

    tp == "token":  set request parameter *parm* to *modify*.  This is
                    typically an app's sign or token value.  The parameter may
                    be in the query string, a form body, or a top-level JSON
                    body key.  Records without that parameter are left unchanged.
    tp == "cookie": replace the value of every ``Cookie:`` header line with
                    *modify*.  This is for ordinary web systems.
    Any other *tp* does nothing.
    """
    results = filter_data.objects.filter(host__icontains=host)
    if tp == "token":
        for record in results:
            # Only the targeted parameter is rewritten.  The old global
            # str.replace also hit unrelated text (HTTP/1.1, Content-Length,
            # ...) and corrupted the whole packet when the value was empty.
            new_data = _modify_token_packet(record.data_packet, parm, modify)
            if new_data:
                record.data_packet = new_data
                record.save()
    elif tp == "cookie":
        # flags must be passed by keyword.  re.sub's 4th positional argument
        # is `count`, so the old call was case-sensitive and capped at 2
        # replacements.  [^\r\n]* keeps a CRLF line ending intact.
        cookie_re = re.compile(r"^Cookie:[^\r\n]*", re.IGNORECASE | re.MULTILINE)
        for record in results:
            # A function replacement stops backslashes in `modify` from being
            # interpreted as regex template escapes.
            record.data_packet = cookie_re.sub(lambda _m: "Cookie:" + modify,
                                               record.data_packet)
            record.save()


def _parm_string(keys):
    """Serialise parameter names in the format stored in ``filter_data.parm``.

    Names are written in reverse order, each followed by a comma.  For example,
    ``['a', 'b']`` becomes ``'b,a,'``.
    """
    return ''.join(key + ',' for key in reversed(keys))


def _form_keys(pairs_text):
    """Return the name of every ``name=value`` pair in an ``&``-joined string."""
    return [pair.split('=')[0] for pair in pairs_text.split('&')]


def _json_keys(body):
    """Return the top-level keys of a JSON-object request body.

    Falls back to the legacy regex when *body* is not valid JSON.  That regex
    misses the last key and is confused by commas inside values.
    """
    try:
        parsed = json.loads(body)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        return list(parsed)
    return re.findall(r'"(.*?)":.*?,', body)


def _has_new_parms(url, keys):
    """Return True unless a filter_data record for *url* already lists every
    non-empty name in *keys*.

    Returns True when no record exists for *url*.
    """
    wanted = {key for key in keys if key}
    for stored in filter_data.objects.filter(url=url).values_list('parm', flat=True):
        if wanted <= set((stored or '').split(',')):
            return False
    return True


# NOTE: this name shadows the builtin ``filter`` inside this module.  It is
# kept unchanged because it is part of the module's interface.
def filter():
    """Copy captured requests from proxy_data into filter_data, skipping duplicates.

    GET requests are keyed by their URL without the query string.  They are
    skipped entirely when there is no query string.  POST requests are keyed
    by the full URL.  Their parameter names come from the body: the top-level
    keys of a ``{...}`` JSON body, otherwise the ``name=value`` form pairs.

    A request is a duplicate when some filter_data record with the same key
    URL already lists all of its parameter names.
    """
    for data in proxy_data.objects.all():
        if data.method == 'GET':
            if '?' not in data.url:
                continue  # no query string, nothing to record
            base_url = data.url.split('?')[0]
            keys = _form_keys(data.url.rsplit('?', 1)[1])
            if _has_new_parms(base_url, keys):
                packet = (data.method + ' ' + data.url + ' ' + data.http_version + '\n'
                          + data.request_headers)
                # Always store under the base URL and with the host.  The old
                # follow-up branch used the full URL and dropped the host, so
                # later dedup lookups and modify_auth never matched it.
                filter_data.objects.create(host=data.host, url=base_url,
                                           data_packet=packet, parm=_parm_string(keys))

        elif data.method == 'POST':
            url = data.url
            body = data.request_content
            packet = (data.method + ' ' + url + ' ' + data.http_version + '\n'
                      + data.request_headers + '\n\n' + body)
            json_match = re.match(r'^{\S+}$', body)
            if not body:
                # Empty POST body: recorded once per URL with no parameters.
                keys, parm = [], ''
            elif json_match:
                keys = _json_keys(json_match.group())
                parm = ','.join(keys)
            else:
                # Form body, e.g. loginDate=2017-10-25&searchVal=&pageIndex=1
                keys = _form_keys(body)
                parm = _parm_string(keys)
            if _has_new_parms(url, keys):
                filter_data.objects.create(host=data.host, url=url,
                                           data_packet=packet, parm=parm)


class autosqli(object):
    """Client for the sqlmap REST API (sqlmapapi) that records scan progress.

    Scans are submitted from stored request packets.  Their task id, API
    server address, logs and findings are tracked on ``inject_data`` records.
    """

    # Seconds to wait for any single sqlmap API response.  Without a timeout,
    # a stalled API server would block the calling worker indefinitely.
    request_timeout = 60

    # Rules for summarising the last sqlmap log entry, checked in order.
    # Each rule is (log level, substring of the message, summary).  An empty
    # substring matches any message of that level.
    _LOG_STATUS_RULES = (
        ("CRITICAL", "connection timed out to the target URL", "连接超时"),
        ("INFO", "", "成功"),
        ("CRITICAL", "all tested parameters do not appear to be injectable", "失败"),
        ("WARNING", "403", "403"),
        ("CRITICAL", "connection dropped or unknown HTTP status code received",
         "connection dropped"),
        ("WARNING", "Internal Server Error", "500 error"),
        ("CRITICAL", "connection reset to the target",
         "connection reset to the target URL or proxy"),
        ("CRITICAL", "unable to connect to",
         "unable to connect to the target URL ('No route to host')"),
        ("WARNING", "404 (Not Found)", "404"),
        ("CRITICAL", "does not exist", "does not exist"),
    )

    def __init__(self):
        self.headers = {'Content-Type': 'application/json'}
        # JSON body sent to /scan/<taskid>/start.  Options are set beforehand.
        self.data = {}
        # Defaults for the single-task helpers (del_taskid, stop_scan, ...)
        # when no explicit taskid/server is passed.
        self.taskid = None
        self.server = None

    # -- low-level helpers ------------------------------------------------

    def _resolve(self, taskid, server):
        """Return ``(taskid, server)``, defaulting to ``self.taskid``/``self.server``.

        Raises ValueError when either value is still missing.
        """
        taskid = taskid or self.taskid
        server = server or self.server
        if not taskid or not server:
            raise ValueError("taskid and server are required")
        return taskid, server

    def _api_get(self, server, path):
        """GET ``server + path`` and return the decoded JSON response."""
        r = requests.get(server + path, headers=self.headers, timeout=self.request_timeout)
        return r.json()

    def _fetch_log(self, taskid, server):
        """Return the list of log entries of *taskid* (``/scan/<taskid>/log``)."""
        return self._api_get(server, '/scan/' + taskid + '/log')['log']

    @staticmethod
    def _format_log(logs):
        """Render log entries as ``[LEVEL][time]message`` lines."""
        return "".join(
            "[" + entry['level'] + "]" + "[" + entry['time'] + "]" + entry['message'] + "\n"
            for entry in logs
        )

    @classmethod
    def _classify_log(cls, logs):
        """Summarise the last log entry via _LOG_STATUS_RULES.

        Returns "" when the log is empty or no rule matches.
        """
        if not logs:
            return ""
        last = logs[-1]
        for level, needle, summary in cls._LOG_STATUS_RULES:
            if last['level'] == level and needle in last['message']:
                return summary
        return ""

    @staticmethod
    def _scan_options(taskid):
        """Return sqlmap options for a new task, built from the project configuration.

        The request file is named after the task id.
        """
        return {
            'paramExclude': SQLMAP_PARMEXCLUDE,
            'threads': SQLMAP_THREADS,
            'level': SQLMAP_LEVEL,
            'risk': SQLMAP_RISK,
            'dbms': SQLMAP_DBMS,
            'requestFile': SQLMAP_REQUESTFILE_PATH + taskid,
            'flushSession': True,
            'retries': SQLMAP_RETRIES,
            'proxy': SQLMAP_PROXY,
            'verbose': SQLMAP_VERBOSE,
        }

    def _launch(self, server, packet):
        """Create, configure and start a scan of *packet* on *server*.

        Writes *packet* to the task's request file before starting the scan.
        Returns the new task id.
        """
        taskid = requests.get(server + "/task/new", timeout=self.request_timeout).json()['taskid']
        requests.post(server + "/option/" + taskid + "/set",
                      data=json.dumps(self._scan_options(taskid)),
                      headers=self.headers, timeout=self.request_timeout)
        try:
            with open(SQLMAP_REQUESTFILE_PATH + taskid, 'w', encoding='utf-8') as fp:
                fp.write(str(packet) + '\n')
        except UnicodeEncodeError:
            # Best effort, as before.  An unencodable packet does not stop the scan.
            pass
        requests.post(server + "/scan/" + taskid + "/start", data=json.dumps(self.data),
                      headers=self.headers, timeout=self.request_timeout)
        return taskid

    # -- single-task API calls ------------------------------------------------

    def del_taskid(self, taskid=None, server=None):
        """Delete a task on the API server (GET /task/<taskid>/delete)."""
        taskid, server = self._resolve(taskid, server)
        return self._api_get(server, '/task/' + taskid + '/delete')

    def list_options(self, taskid=None, server=None):
        """List all options of a task (GET /option/<taskid>/list)."""
        taskid, server = self._resolve(taskid, server)
        return self._api_get(server, '/option/' + taskid + '/list')

    def get_options(self, taskid=None, server=None, options=()):
        """Read selected option values of a task (POST /option/<taskid>/get).

        NOTE(review): the body is a JSON list of option names, as current
        sqlmapapi expects.  Older sqlmap releases used a different payload;
        confirm against the deployed version.
        """
        taskid, server = self._resolve(taskid, server)
        r = requests.post(server + '/option/' + taskid + '/get',
                          data=json.dumps(list(options)),
                          headers=self.headers, timeout=self.request_timeout)
        return r.json()

    def start_scan(self, result, status):
        """Submit (status "start") or resubmit (status "restart") a scan.

        "start":   *result* has ``taskid``, ``data_packet`` and ``url``.  It is
                   skipped if it already has a taskid.  Otherwise a random
                   server from SQLMAP_API_SERVER scans its packet, the record
                   gets status '1' and the task id, and the inject_data row for
                   the task is created or updated.
        "restart": *result* has ``api_address`` and ``packet``.  A fresh task
                   is started on the same server.  The record's
                   log/parameter fields are reset, run_status becomes
                   "running" and status becomes '0'.
        """
        if status == "start":
            if result.taskid:
                return
            server = random.choice(SQLMAP_API_SERVER)
            taskid = self._launch(server, result.data_packet)
            result.status = '1'
            result.taskid = taskid
            result.save()

            # Record the injection task.
            record = inject_data.objects.get_or_create(taskid=taskid)[0]
            record.packet = result.data_packet
            record.host = urlparse(result.url).netloc
            record.url = result.url
            record.api_address = server
            record.save()

        elif status == "restart":
            taskid = self._launch(result.api_address, result.packet)
            result.log_status = ""
            result.run_status = "running"
            result.parameter = ""
            result.taskid = taskid
            result.status = '0'
            result.save()

    def stop_scan(self, taskid=None, server=None):
        """Stop a running scan (GET /scan/<taskid>/stop)."""
        taskid, server = self._resolve(taskid, server)
        return self._api_get(server, '/scan/' + taskid + '/stop')

    def status_scan(self, taskid=None, server=None):
        """Return a scan's status response (GET /scan/<taskid>/status)."""
        taskid, server = self._resolve(taskid, server)
        return self._api_get(server, '/scan/' + taskid + '/status')

    def kill_scan(self, taskid=None, server=None):
        """Kill a running scan (GET /scan/<taskid>/kill)."""
        taskid, server = self._resolve(taskid, server)
        return self._api_get(server, '/scan/' + taskid + '/kill')

    # -- result bookkeeping -----------------------------------------------------

    def scanstatus(self, taskid, server):
        """Store the task's log on its inject_data record and summarise it.

        Returns a short summary of the last log entry, or "" when the log is
        empty or no rule matches.
        """
        logs = self._fetch_log(taskid, server)
        record = inject_data.objects.get(taskid=taskid)
        record.log = self._format_log(logs)
        record.save()
        return self._classify_log(logs)

    def sqlmaplog(self, taskid, server):
        """Return the task's log as ``[LEVEL][time]message`` lines."""
        return self._format_log(self._fetch_log(taskid, server))

    def runstatus(self, result):
        """Poll the sqlmap task of *result* and store its state.

        Only records with ``status == '0'`` are polled.  *result* must provide
        ``api_address`` and ``taskid``; it is presumably an inject_data record
        (confirm).

        When the task has terminated, its log summary, full log and findings
        are saved.  If sqlmap returned scan data, the record gets status '2'
        and '漏洞存在'; otherwise status '1' and '漏洞不存在'.  For 'running',
        connection failures ('timeout') and a missing status key ('keyerror'),
        only ``run_status`` is saved.
        """
        if result.status != '0':
            return
        base = result.api_address + '/scan/' + result.taskid
        try:
            run_status = requests.get(base + '/status',
                                      timeout=self.request_timeout).json()['status']
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            run_status = 'timeout'
        except KeyError:
            run_status = "keyerror"

        if run_status == 'terminated':
            # sqlmap has finished.
            log_status = self.scanstatus(result.taskid, result.api_address)
            sqlmap_log = self.sqlmaplog(result.taskid, result.api_address)
            data = requests.get(base + '/data', timeout=self.request_timeout).json()['data']
            if data:
                # data[0] holds the target and data[1] the injection list.
                # Guard against a missing or empty list, which used to raise
                # IndexError / UnboundLocalError.
                injections = data[1]['value'] if len(data) > 1 else []
                parameter = ''
                dbms = ''
                for injection in injections:
                    dbms = injection['dbms']
                    parameter = injection['parameter'] + ',' + parameter
                result.url = data[0]['value']['url']
                result.dbms = dbms
                result.vul_info = '漏洞存在'
                result.parameter = parameter
                result.status = '2'
            else:
                result.status = '1'
                result.vul_info = '漏洞不存在'
            result.run_status = run_status
            result.log_status = log_status
            result.log = sqlmap_log
            result.save()
        elif run_status in ('running', 'timeout', 'keyerror'):
            result.run_status = run_status
            result.save()


# (The @task decorator is disabled, so this runs synchronously in the caller.)
def get_runstatus(results):
    """Refresh the sqlmap run status of every record in *results*.

    A single autosqli client is created up front and reused for each record.
    """
    refresh = autosqli().runstatus
    for record in results:
        refresh(record)


def _running_sqlmap_count():
    """Return how many ``ps aux`` lines mention sqlmap (0 if there is no output)."""
    # The context manager closes the pipe.  The original leaked it.
    with os.popen("ps aux | grep sqlmap | grep -v grep | wc -l") as pipe:
        output = pipe.read().strip()
    return int(output or 0)


@task
def run_sqlmap(id, status):
    """Celery task: throttle, then start or restart one sqlmap scan.

    Blocks, polling every 10 seconds, while more than SQLMAP_LIMIT_RUN + 4
    sqlmap processes are running.

    status == "start":   *id* is a filter_data primary key.
    status == "restart": *id* is an inject_data primary key.
    Any other status does nothing after the wait.
    """
    # NOTE(review): the +4 headroom presumably covers sqlmap processes that
    # are not scans (e.g. the sqlmapapi servers themselves).  Confirm.
    limit = int(SQLMAP_LIMIT_RUN) + 4
    # Re-count on every pass.  The original sampled once before the loop, so
    # once over the limit it slept forever.
    while _running_sqlmap_count() > limit:
        print("waitting")
        time.sleep(10)
    sqli = autosqli()
    if status == "start":
        sqli.start_scan(filter_data.objects.get(id=id), status)
    elif status == "restart":
        sqli.start_scan(inject_data.objects.get(id=id), status)


